from typing import Any, Literal

from pydantic import BaseModel, Field


class TaskData(BaseModel):
    # Payload describing one state update of a task run.
    #
    # Updates that belong to the same run share a ``run_id``; ``state`` tracks
    # the lifecycle ("new" -> "running" -> "complete") and ``result`` carries
    # the outcome.
    #
    # NOTE: documented with comments rather than a class docstring on purpose;
    # pydantic publishes a model's docstring as its JSON-schema description.

    name: str | None = Field(
        description="Name of the task.", default=None, examples=["Check input safety"]
    )
    run_id: str = Field(
        description="ID of the task run to pair state updates to.",
        default="",
        examples=["847c6285-8fc9-4560-a83f-4e6285809254"],
    )
    state: Literal["new", "running", "complete"] | None = Field(
        description="Current state of given task instance.",
        default=None,
        examples=["running"],
    )
    result: Literal["success", "error"] | None = Field(
        description="Result of given task instance.",
        default=None,
        # The example must be one of the allowed literals; "running" is a
        # ``state`` value and would fail validation for this field.
        examples=["success"],
    )
    data: dict[str, Any] = Field(
        description="Additional data generated by the task.",
        # A literal default keeps "default": {} in the generated schema;
        # pydantic copies field defaults per instance, so it is not shared.
        default={},
    )

    def completed(self) -> bool:
        """Return True when the task has reached the "complete" state."""
        return self.state == "complete"

    def completed_with_error(self) -> bool:
        """Return True when the task is complete and its result is "error"."""
        return self.state == "complete" and self.result == "error"


class TaskDataStatus:
    """Render a stream of task updates into a single Streamlit status element.

    Keeps the most recent update per ``run_id`` so the overall status state
    can be derived from every task seen so far.
    """

    def __init__(self) -> None:
        """Create the (initially unlabeled) status element and an empty registry."""
        # Imported here rather than at module level, so importing this module
        # does not itself import streamlit.
        import streamlit as st

        self.status = st.status("")
        self.current_task_data: dict[str, TaskData] = {}

    def add_and_draw_task_data(self, task_data: TaskData) -> None:
        """Write one task update into the status element and refresh its state.

        Args:
            task_data: The update to display; it replaces any earlier update
                stored under the same ``run_id``.
        """
        container = self.status

        # Pick the sentence ending that matches the task's lifecycle state.
        # An unrecognized (or missing) state adds nothing to the headline.
        if task_data.state == "new":
            suffix = "has :blue[started]. Input:"
        elif task_data.state == "running":
            suffix = "wrote:"
        elif task_data.state == "complete":
            suffix = (
                ":green[completed successfully]. Output:"
                if task_data.result == "success"
                else ":red[ended with error]. Output:"
            )
        else:
            suffix = ""

        # Headline, payload, then a separator line.
        for item in (f"Task **{task_data.name}** {suffix}", task_data.data, "---"):
            container.write(item)

        registry = self.current_task_data
        is_first_update_for_run = task_data.run_id not in registry
        if is_first_update_for_run:
            # The label always names the most recently started task.
            container.update(label=f"Task: {task_data.name}")
        registry[task_data.run_id] = task_data

        tracked = registry.values()
        if not all(item.completed() for item in tracked):
            # Still "running" while any tracked task is unfinished.
            overall = "running"
        elif any(item.completed_with_error() for item in tracked):
            # Everything finished, but at least one task errored.
            overall = "error"
        else:
            # Everything finished without an error result.
            overall = "complete"
        container.update(state=overall)
